Skip to content

Add Granite 4.1 Vision (granite4_vision)#45597

Open
artem-spector wants to merge 48 commits intohuggingface:mainfrom
artem-spector:add-gv41
Open

Add Granite 4.1 Vision (granite4_vision)#45597
artem-spector wants to merge 48 commits intohuggingface:mainfrom
artem-spector:add-gv41

Conversation

@artem-spector
Copy link
Copy Markdown

What does this PR do?

Adds built-in support for Granite 4.1 Vision (granite4_vision), IBM's multimodal vision-language model for enterprise document understanding.

Architecture highlights

  • Vision encoder: SigLIP2 (google/siglip2-so400m-patch16-384), tiled 384×384 patches
  • Window Q-Former projector: 4×4 patch windows compressed to 2×2 query tokens via cross-attention (downsample_rate="4/8")
  • DeepStack feature injection: 8 vision-to-LLM injection points across two mechanisms:
    • LayerDeepstack: features from 4 vision encoder depths injected at 4 LLM layers (reversed order — deepest vision → earliest LLM)
    • SpatialDeepstack: deepest features split into 4 spatial offset groups (TL/TR/BL/BR), injected at 4 later LLM layers
  • Language model: GraniteForCausalLM (3.5B) with a rank-256 LoRA adapter (same-repo, LM-only)

Files added

File Purpose
modular_granite4_vision.py Source of truth — inherits from LLaVA-Next, overrides novel components
configuration_granite4_vision.py Config (generated)
modeling_granite4_vision.py Model (generated)
processing_granite4_vision.py Unified processor (generated)
image_processing_granite4_vision.py Torchvision-based image processor
image_processing_pil_granite4_vision.py PIL/NumPy image processor
tests/models/granite4_vision/ Modeling, image processing, and processor tests
docs/source/en/model_doc/granite4_vision.md Model documentation

Auto-registration

  • Config: auto-generated via configuration_granite4_vision.py model_type
  • Modeling: MODEL_MAPPING_NAMES + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
  • Processing + image processing: registered in respective auto files

Tests

  • Unit tests pass locally (pytest tests/models/granite4_vision/ -x -q)
  • @slow integration tests load real checkpoint and assert outputs within tolerance
  • make style and make check-repo pass (3 remaining failures are pre-existing upstream issues: mlinter version mismatch and Sam3Lite incomplete model)

Before submitting

  • This PR is not a duplicate
  • I have read the contributor guidelines
  • The documentation reflects the changes
  • The tests pass

Related

@artem-spector
Copy link
Copy Markdown
Author

artem-spector commented Apr 23, 2026

I've traced the root cause of the check_repository_consistency and tests_torch failures to a specific upstream commit:

[Sam3LiteText] Remove unnecessary modules/configs (#45535) (7439ac0)

This commit removed Sam3LiteTextViTConfig and Sam3LiteTextVisionConfig from the modeling file but left them referenced in auto_mappings.py, causing:

  • AttributeError: module transformers has no attribute Sam3LiteTextViTConfig (357 test failures)
  • check_repo failure: Sam3LiteTextVisionConfig appears in CONFIG_MAPPING_NAMES but is not defined

This is reproducible on main independently of our PR.

Question for reviewers: Should we include a fix for this in our PR (removing the stale entries from auto_mappings.py), or would you prefer to handle it in a separate hotfix? Happy to do either.

@artem-spector
Copy link
Copy Markdown
Author

Opened a dedicated issue for the upstream regression: #45600

@zucchini-nlp
Copy link
Copy Markdown
Member

LMK when ready for review, and ig this PR supersedes #45350?

@artem-spector
Copy link
Copy Markdown
Author

artem-spector commented Apr 25, 2026

@zucchini-nlp - yes, this PR supersedes #45350. Its our team that is responsible for producing/release IBM vision models.
This PR is ready for review from my side.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@artem-spector great usage of modular!

Seems like the model uses granite llm as backbone with deepstack features. We will need to add an llm class in that case, since calling each backbone layer manually doesn't align well with our API. We can use modular to copy everything except for a single forward

As per adapters, can you explain how the weights are released? I am not really sure we have to manually add merge_adapters, prob I can suggest a cleaner way

Comment on lines +41 to +48
```bibtex
@misc{granite-vision-4.1-4b,
title={Granite Vision 4.1},
author={IBM Granite Vision Team},
year={2026},
url={https://huggingface.co/ibm-granite/granite-vision-4.1-4b}
}
```
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we dont need a bibtext entry and as long as there is a link to HF papers/arxiv, that is enought

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed.

Comment on lines +66 to +68
device=0,
torch_dtype=torch.bfloat16,
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: these two are by default "auto" so we dont need to manually set

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed.


processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, torch_dtype is "auto" by default and can be deleted

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed.

Comment on lines +165 to +177
## Notes

- The model includes LoRA adapters. Call `model.merge_lora_adapters()` after loading to merge them into base weights for faster inference.

- Set `padding_side="left"` during batched generation for more accurate results.

```py
processor.tokenizer.padding_side = "left"
```

- The model supports specialized task tags for document extraction: `<chart2csv>`, `<chart2summary>`, `<chart2code>`, `<tables_html>`, `<tables_otsl>`, `<tables_json>`. Pass these as the text prompt along with a document image.

- For key-value pair extraction, provide a JSON schema describing the fields to extract. The model returns structured JSON matching the schema.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets move this block as Usage Tips section, before the usage example code snippets

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — moved to a "Usage Tips" section before the code examples.

@@ -0,0 +1,155 @@
import math
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following "one-model - one-file" philosophy, it is better put inside modular/modeling files

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — downsampling_granite4_vision.py deleted, all contents inlined into the modular.

Comment on lines 81 to 87
"openai-privacy-filter": "OpenAIPrivacyFilterConfig",
"lasr": "LasrCTCConfig",
"wav2vec2-with-lm": "Wav2Vec2Config",
"granite4-vision": "Granite4VisionConfig",
"hy-v3": "HYV3Config",
"slanet": "SLANetConfig",
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a few bad rebases :)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed the stale entries introduced by bad rebases.

Comment thread src/transformers/conversion_mapping.py Outdated
Comment on lines +463 to +466
WeightRenaming(
source_patterns=r"(vision_tower\.)vision_model\.",
target_patterns=r"\1",
),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is not needed anymore, we added PrefixWeights recently and fixed all llava models

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed the granite4_vision entry from conversion_mapping.py.

@@ -0,0 +1,253 @@
# Copyright 2025 IBM. All rights reserved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2026 :)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — Copyright 2026 IBM and The HuggingFace Team.

Comment on lines +50 to +57
class Granite4VisionModelTester(VLMModelTester):
base_model_class = Granite4VisionModel
config_class = Granite4VisionConfig
conditional_generation_class = Granite4VisionForConditionalGeneration
text_config_class = GraniteConfig
vision_config_class = CLIPVisionConfig

def __init__(self, parent, **kwargs):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need only this tester, since processing is identical to llava-next. Thanks for using VLMTester 🤩

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed test_image_processing_granite4_vision.py entirely (processing is identical to LlavaNext, no re-definition needed).

Comment thread utils/check_repo.py Outdated
Comment on lines +551 to +557
"granite4_vision",
"falcon3",
"megatron_gpt2",
"code_llama",
"hy_v3",
"openai_privacy_filter",
"slanet",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also bad rebase

@artem-spector
Copy link
Copy Markdown
Author

@zucchini-nlp I'm ready for the second round :)

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @artem-spector ! Great work, glad to see a singlr modular file.

Left some comments on further cleaning-up, and a few questions. I just noticed that all image placeholders are filled with zeros, and not sure if that is intended. If we actually have no image features scattered, can we stop adding that many placeholders without quality degradation? Dummy and unnecessary token ids increase total length and looks like a waste of resources

model_id = "ibm-granite/granite-vision-4.1-4b"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id).eval()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: [...], device_map="auto") without eval, shoudl be already in eval mode when loaded

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed .eval().

("glm_image", {"pil": "GlmImageImageProcessorPil", "torchvision": "GlmImageImageProcessor"}),
("glpn", {"pil": "GLPNImageProcessorPil", "torchvision": "GLPNImageProcessor"}),
("got_ocr2", {"pil": "GotOcr2ImageProcessorPil", "torchvision": "GotOcr2ImageProcessor"}),
("granite4_vision", {"pil": "Granite4VisionImageProcessorPil", "torchvision": "Granite4VisionImageProcessor"}),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can delete this, already mapped to 'LlavaNext'

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed.

@@ -0,0 +1,734 @@
# Copyright 2025 IBM. All rights reserved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ig you forgot to commit and push 😄



@dataclass
class Granite4VisionImageFeaturesOutput(ModelOutput):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think same as qwen3_vl.BaseModelOutputWithDeepstackFeatures

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling (same approach as qwen3_vl.BaseModelOutputWithDeepstackFeatures), with a deepstack_features field added.

Comment on lines +95 to +97
class Granite4VisionTextConfig(PreTrainedConfig):
model_type = "granite4_vision_text"
base_config_key = "text_config"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I am following, are we not supposed to inherit from GraniteConfig? Current class has no attributes

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — Granite4VisionTextConfig inherits GraniteConfig directly. It has no additional attributes because the text config is fully specified by GraniteConfig; the subclass exists only to set model_type = "granite4_vision_text" and base_config_key = "text_config".

past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
deepstack_features=outputs.deepstack_features,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the only diff from llava-next is returning deepstack_features?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes — the main differences from LlavaNextForConditionalGeneration are: (1) the model uses Granite4VisionModel which has deepstack injection, (2) logits are scaled by text_config.logits_scaling, and (3) deepstack_features is threaded through the output. Everything else is inherited.

logits_to_keep=logits_to_keep,
**kwargs,
)
model_inputs = self._init_hybrid_cache(**model_inputs)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be able to delete this line to inti cache. Correct cache should be init by generationMixin._prepare_cache_for_generation automatically

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed; GenerationMixin._prepare_cache_for_generation handles cache initialization.

Comment on lines +140 to +146
@unittest.skip("Granite4VisionImageFeaturesOutput has no hidden_states field")
def test_get_image_features_hidden_states(self):
pass

@unittest.skip("Granite4VisionImageFeaturesOutput has no attentions field")
def test_get_image_features_attentions(self):
pass
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two should be fixed when we add **kwargs and return a BaseModelOutputWithDeepstackFeatures

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling so all five tests pass. skip_test_image_features_output_shape = True remains because last_hidden_state isn't meaningful for this output type, but hidden_states, pooler_output, and field presence checks all pass.

Comment on lines +148 to +154
@unittest.skip("Base model forward returns ModelOutputWithPast, not CausalLMOutput with loss")
def test_training(self):
pass

@unittest.skip("QFormer submodules not initialized by init_weights from meta device")
def test_can_init_all_missing_weights(self):
pass
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these also should be fixable, I don't think this is a valid reason to skip

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done for test_training — the skip was stale; Granite4VisionForConditionalGeneration computes a loss, and the framework skips Granite4VisionModel automatically via MODEL_MAPPING_NAMES.\n\ntest_can_init_all_missing_weights remains skipped: Blip2QFormerModel submodules aren't initialized from meta device by our _init_weightsBlip2QFormerPreTrainedModel._init_weights only handles Blip2ForConditionalGeneration instances, not the standalone Blip2QFormerModel. Happy to investigate further if you'd like.

Comment thread utils/check_config_attributes.py Outdated
Comment on lines +102 to +105
"Granite4VisionConfig": [
"multimodal_projector_bias",
"projector_hidden_act",
],
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if actually not used, lets just delete them from config class. When inheriting config, you can define

multimodal_projector_bias = AtributeErro()

and it will not copy this field

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — multimodal_projector_bias and projector_hidden_act are shadowed with AttributeError() at class level on Granite4VisionConfig, so the framework sees them as intentionally inaccessible. The SPECIAL_CASES_TO_ALLOW entry was removed.

@artem-spector
Copy link
Copy Markdown
Author

Hi @zucchini-nlp, sorry for delay - I was chasing a performance problem, which turn out to be not real.

Summary of changes:

  • Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling instead of bare ModelOutput, which unblocks the common image-features tests
  • Added @capture_outputs to Granite4VisionTextModel.forward so output_hidden_states propagates correctly through generation
  • Removed the manual DynamicCache init from forward — GenerationMixin._prepare_cache_for_generation handles this
  • get_image_features now properly passes output_hidden_states through (via kwarg or config fallback)
  • qformer_config is now fully specified at init time rather than patched post-super().init()
  • _can_record_outputs and _deepstack_inject moved up to Granite4VisionPreTrainedModel
  • One-letter variable names renamed for clarity; unused inherited config attrs now raise AttributeError with a clear message

Open question — test_can_init_all_missing_weights:

This test verifies that every weight can be re-initialized from meta device via _init_weights. It fails because WindowQFormerDownsampler embeds a Blip2QFormerModel (a full PreTrainedModel from a different family). When from_pretrained(None, state_dict={}) walks submodules, our _init_weights has no isinstance branches for QFormer's internal layers (attention, embeddings, etc.), so they stay uninitialized.
Is there a recommended approach here, or is skipping acceptable for embedded third-party PreTrainedModel submodules?

@zucchini-nlp
Copy link
Copy Markdown
Member

When from_pretrained(None, state_dict={}) walks submodules, our _init_weights has no isinstance branches for QFormer's internal layers (attention, embeddings, etc.), so they stay uninitialized.

That is weird, usually our initializer walks down each sub-module and initializes its weights. Unless Qformer has special buffers that need to be init, it should be fine. I can check it on Monday and do another review pass

@artem-spector
Copy link
Copy Markdown
Author

When from_pretrained(None, state_dict={}) walks submodules, our _init_weights has no isinstance branches for QFormer's internal layers (attention, embeddings, etc.), so they stay uninitialized.

That is weird, usually our initializer walks down each sub-module and initializes its weights. Unless Qformer has special buffers that need to be init, it should be fine. I can check it on Monday and do another review pass

will re-check too

artemspector and others added 11 commits May 3, 2026 10:10
Full implementation of IBM Granite 4.1 Vision as a built-in HF model:
- Modular implementation (modular_granite4_vision.py)
- Generated files: config, modeling, image processing, processing
- Auto-registration: config, modeling, processing, image processing
- Tests: modeling (unit + @slow), image processor, processor
- Documentation (docs/source/en/model_doc/granite4_vision.md)
- WeightRenaming to handle SiglipVisionModel vision_model. nesting

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Upstream moved CONFIG_MAPPING_NAMES to auto_mappings.py. Add
granite4_vision entry there; resolve leftover conflict markers in
configuration_auto.py (granite4_vision is already in modeling_auto.py
and processing_auto.py).

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Remove granite4_vision from MISSING_IMAGE_PROCESSOR_MAPPING_NAMES (auto-discovered via TorchvisionBackend/PilBackend)
- Add granite4-vision to HARDCODED_CONFIG_FOR_MODELS in auto_docstring.py
- Add granite4_vision to DOC_MODEL_NAMES_NOT_IN_AUTO in check_repo.py
- Fix import sort in models/__init__.py and test file
- Regenerate auto_mappings.py via check_auto.py --fix_and_overwrite
- Add dates to granite4_vision.md

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
- Fix processing_auto.py sort order (sort_auto_mappings)
- Add hy-v3, openai-privacy-filter, slanet to HARDCODED_CONFIG_FOR_MODELS
- Add hy_v3, openai_privacy_filter, slanet to DOC_MODEL_NAMES_NOT_IN_AUTO
  (new upstream models missing from these registries)

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
sam3_vision_model and sam3_vit_model were incorrectly mapped to
Sam3LiteTextVisionConfig/Sam3LiteTextViTConfig instead of
Sam3VisionConfig/Sam3ViTConfig (and sam3_lite_text module instead of sam3).
These are unrelated to granite4_vision; restoring upstream/main values.

Signed-off-by: artemspector <[email protected]>
…ebase regeneration

These three upstream model entries were accidentally removed from CONFIG_MAPPING_NAMES
in auto_mappings.py by a previous run of check_auto.py --fix_and_overwrite during
an incomplete rebase state. Restoring verbatim from upstream/main.

Signed-off-by: artemspector <[email protected]>
artemspector and others added 21 commits May 3, 2026 10:10
- Item 20: drop get_image_token_mask override, use parent's get_placeholder_mask
- Item 29: delete test_image_processing_granite4_vision.py (identical to LlavaNext)

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Pass output_attentions/output_hidden_states explicitly to language_model
  in Granite4VisionModel.forward (were swallowed as explicit params, not
  forwarded via **kwargs)
- Collect all_hidden_states and all_self_attns in Granite4VisionTextModel
  layer loop; add output_attentions/output_hidden_states params
- Fix qformer_config dict→object conversion to run before super().__post_init__()
  so _attn_implementation.setter doesn't hit a raw dict during sub_configs iteration
- Use Blip2QFormerConfig directly in sub_configs (instead of AutoConfig) so
  save/load round-trip resolves the type correctly; add missing import to
  generated configuration_granite4_vision.py

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…es AutoConfig

blip_2_qformer is registered in CONFIG_MAPPING so AutoConfig resolves it correctly.
Moving the Blip2QFormerConfig import inside __post_init__ avoids a cross-model
top-level import that the modular converter drops from the generated file.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ing public classes

- IGNORE_NON_TESTED + IGNORE_NON_AUTO_CONFIGURED: Granite4VisionTextModel is an
  internal subcomponent tested implicitly through Granite4VisionModel
- Doc: add autodoc entries for Granite4VisionTextConfig, Granite4VisionTextModel,
  Granite4VisionImageProcessor, Granite4VisionImageProcessorPil

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Needed for ruff F821 (undefined name) to pass under make style.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…isionModel.forward

Aligns with reviewer feedback: these args are not needed in the explicit
signature since they flow through kwargs: Unpack[TransformersKwargs].

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- TRF010: add @strict to Granite4VisionTextConfig (direct PreTrainedConfig subclass)
- TRF002: set base_model_prefix = "model" on Granite4VisionTextModel (was "")
- TRF009: add trf-ignore comment on Blip2QFormerModel lazy import
  (cross-model import is intentional — QFormer is a shared building block)

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
- Reorder imports in modular to satisfy ruff isort (stdlib → third-party → first-party)
- Sync processing_granite4_vision.py to match converter output
  (BatchFeature from feature_extraction_utils, no model_type on processor)

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
downsampling_granite4_vision.py, image_processing_granite4_vision.py, and
image_processing_pil_granite4_vision.py are regenerated by the converter but
were previously intentionally deleted: image processors delegate to LlavaNext
(registered in image_processing_auto.py), and downsampling is inlined in
modular/modeling.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…ctor naming

- Granite4VisionTextConfig restored as proper GraniteConfig subclass
- Granite4VisionConfig.__post_init__: convert dict->config before super() so
  _attn_implementation.setter sees config objects; patch vision-size fields after super()
- Use CONFIG_MAPPING/AutoModel at module top-level (no lazy imports)
- Add _can_record_outputs to Granite4VisionTextModel for hidden_states/attentions
  capture via @capture_outputs decorator
- Add Granite4VisionTextAttention/TextDecoderLayer stubs in modular so converter
  generates the registry entries pointing to the correct layer classes
- WindowQFormerDownsampler renamed to Granite4VisionWindowQFormerDownsampler
- interpolate_downsample/spatial_offset_downsample take explicit size args (not config)
- Remove output_attentions/output_hidden_states from forward signatures (handled by
  @capture_outputs and **kwargs); use BaseModelOutputWithPast return type
- Remove prepare_inputs_for_generation (handled by parent)
- Remove _init_hybrid_cache (GraniteMoeHybrid leftover from 4.0)
- auto_mappings.py: use LlavaNextImageProcessor(Pil) instead of model-specific copies
- docs: remove .eval() from example

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
…onfig attrs

- _windowed_raster/unwindowed_raster: x -> features, x_win -> windowed_features
- __init__: q, w -> query_side_str, window_side_str
- Granite4VisionConfig: shadow LlavaNextConfig's multimodal_projector_bias and
  projector_hidden_act with AttributeError() so check_config_attributes passes
  without SPECIAL_CASES_TO_ALLOW entry
- Remove Granite4VisionConfig from check_config_attributes.py SPECIAL_CASES_TO_ALLOW

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
…d patching

Peek at vision_config.hidden_size (or its dict equivalent) before super() and
include hidden_size, num_attention_heads, encoder_hidden_size directly in the
CONFIG_MAPPING["blip_2_qformer"]() constructor call. This avoids mutating the
config object after super().__post_init__() runs.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
…ainedModel

Follows qwen3_vl pattern: _can_record_outputs and _deepstack_inject belong on
the shared PreTrainedModel base class, not on TextModel. TextAttention/
TextDecoderLayer stubs are defined before PreTrainedModel so the converter
generates locally-scoped classes for the _can_record_outputs registry.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
…cCache

- Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling so
  the common test framework can introspect last_hidden_state/pooler_output/
  hidden_states/attentions fields (removes 5 test skips)
- Add @capture_outputs to Granite4VisionTextModel.forward so output_hidden_states
  is collected via hooks and propagated through to the causal LM output (fixes
  test_assisted_decoding_matches_greedy_search)
- get_image_features: populate hidden_states from vision tower when
  output_hidden_states=True (via kwarg or config); removes test skip
- Remove stale test_training skip (ForConditionalGeneration computes loss;
  base model is skipped automatically by MODEL_MAPPING_NAMES check)
- Delete DynamicCache init block from Granite4VisionTextModel.forward;
  GenerationMixin._prepare_cache_for_generation handles this
- Import BaseModelOutputWithPooling, capture_outputs; drop DynamicCache import

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Signed-off-by: artemspector <[email protected]>
The test actually passes — framework's recursive apply() handles QFormer
submodule weights correctly via the nn.Linear branch of _init_weights.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
artemspector and others added 2 commits May 3, 2026 10:16
…RMSNorm branches

These were missing, causing test_can_init_all_missing_weights to fail in CI
when weights initialized from meta device didn't match __init__ values.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, granite4_vision

artemspector and others added 2 commits May 3, 2026 10:42
Granite4VisionTextRMSNorm is defined in the generated file, not modular,
so referencing it by name in modular_granite4_vision.py is an undefined name.
Use attribute-based detection (has weight + variance_epsilon, not Linear) instead.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
@artem-spector
Copy link
Copy Markdown
Author

When from_pretrained(None, state_dict={}) walks submodules, our _init_weights has no isinstance branches for QFormer's internal layers (attention, embeddings, etc.), so they stay uninitialized.

That is weird, usually our initializer walks down each sub-module and initializes its weights. Unless Qformer has special buffers that need to be init, it should be fine. I can check it on Monday and do another review pass

will re-check too

you right, test_can_init_all_missing_weights works

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, last comments on style, flagged a few that weren't addressed/pushed from prev review

Approving so we will request core maintainer review to merge after fixing left-over comments

("glm_image", {"pil": "GlmImageImageProcessorPil", "torchvision": "GlmImageImageProcessor"}),
("glpn", {"pil": "GLPNImageProcessorPil", "torchvision": "GLPNImageProcessor"}),
("got_ocr2", {"pil": "GotOcr2ImageProcessorPil", "torchvision": "GotOcr2ImageProcessor"}),
("granite4_vision", {"pil": "LlavaNextImageProcessorPil", "torchvision": "LlavaNextImageProcessor"}),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this belong in image_processing_auto.py. Was it auto-generated or added manually?

Comment on lines +146 to +154
# Peek at vision hidden_size before super() to build a fully-specified qformer_config,
# avoiding any runtime field patching after super().
if isinstance(self.vision_config, dict):
vision_hidden_size = self.vision_config.get("hidden_size", 1152)
elif self.vision_config is not None:
vision_hidden_size = self.vision_config.hidden_size
else:
vision_hidden_size = 1152

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vision/text configs cannot be None, no? In other words, we create a default config if not passes, just like with qformer below. Then we won't need a magical 1152 hardcoded here

If that is about the order, you can call super() at the beginning and it will unwrap first. Also we can completely override __post_init__ by calling PreTrainedConfig.__post_init__(**kwargs) instead of super

Comment on lines +287 to +293
return (
features.view(batch, side, side, channels)
.view(batch, num_win, window_size, num_win, window_size, channels)
.transpose(2, 3)
.flatten(0, 2)
.flatten(1, 2)
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ultra nit: for sake of readability I prefer to split into several lines. Same comment for unwindowed_raster

Comment on lines +344 to +355
if isinstance(module, Granite4VisionTextRotaryEmbedding):
# Non-persistent buffers (inv_freq, original_inv_freq) are replaced with
# torch.empty_like() garbage by _move_missing_keys_from_meta_to_device.
# Recompute them here so _initialize_missing_keys restores correct values.
rope_type = module.config.rope_parameters.get("rope_type", "default")
if rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
else:
rope_init_fn = module.compute_default_rope_parameters
inv_freq, attention_scaling = rope_init_fn(module.config, module.inv_freq.device)
init.copy_(module.inv_freq, inv_freq)
init.copy_(module.original_inv_freq, inv_freq)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't push commit? 😄 Same for LN and embed modules, calling super is usually enough unless model requires special weight init

Comment on lines +368 to +369
base_model_prefix = "model"
_no_split_modules = ["Granite4VisionTextDecoderLayer"]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment, I think you have one local commit unpushed

Comment on lines +684 to +691
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also delete these and decorate merge_defaults_with_config. Also you can delete the decorator in L669 and it will be copied from llava by modular

deepstack_features = None
vision_mask = None
image_features = None
if pixel_values is not None and pixel_values.size(0) > 0:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad copy from llava-next, checking size(0) shoulnd't be needed

Comment on lines +467 to +474
concat_features = torch.cat(packed_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
if idx == 0:
vision_mask = self.get_image_token_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
)
inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
deepstack_features.append((llm_layer_idx, concat_features))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see, could you add a small inline comment that model relies only on deepstack_features and assigning 0 value is on purpose

Comment on lines +772 to +779
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here re decorarots and x if x else self.x

Comment on lines +721 to +725
# Bypass nn.Module.__call__ overhead by calling the unbound forward directly.
# nn.Module.__call__ has non-trivial per-call overhead that accumulates across 40 layers × N steps.
outputs = Granite4VisionTextModel.forward(
self.language_model,
input_ids=None,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks weird, is that smth from modular or simply calling self.language_model has an overhead? Do you have an intuition what could be reason, I don't think we can keep it this way

@zucchini-nlp
Copy link
Copy Markdown
Member

run-slow: granite4_vision

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants